% goal here is to generate info on what is happening in the stim sequence:
% generates stiminfo, which is a 3 x framesperloop x numloops array (one
% page per loop / stim presentation)
% first row is the number of the stim that's happening during the stim
% frames (i.e. first stim is 1, second stim is 2...)
%
% second row is the index (row of uniqueID) of the stimulus presented
%
% third row indicates if the stim is on or not for that specific frame:
% 0 means it is off and 1 means it is on. The kind of stim (derived from
% 'movindicator', which refers to the movies in 'mov') is recorded in
% uniqueID, not in this row
%
% uniqueID lists the distinct stimuli, i.e. where on the screen the stim is
% happening and what kind of stim it is
%
% movieparams tells you what sort of stim each movie is: the first row is
% the movie class (blank, blinking box, bp filtered noise, circle...) and
% the second row is the size of the stimulus
%
% updates from getStimInfo_v4.3: adds ability to process moving stimuli;
% note this changes the formatting of uniqueID (and stimID) to be:
% each row is a separate unique stimulus
% first column is starting x position
% second column is starting y position
% third column is the stim type
% fourth column is ending x position
% fifth column is ending y position
% sixth column is the size of the stim (circlesize, or derived from
% imsize*imageMag for full-image movies)
% seventh column is the stim duration
% these seven columns repeat if there are multiple simultaneous stimuli

function [stiminfo,uniqueID,movieparams,stimID] = getStimInfo_v4_4(movname,metadataname,results)
% GETSTIMINFO_V4_4  Annotate every imaging frame with the stimulus shown.
%
%   [stiminfo,uniqueID,movieparams,stimID] = getStimInfo_v4_4(movname,metadataname,results)
%
%   movname      - string array (or cell array) of movie file names. Element
%                  1 supplies the frame metadata; when it is not readable as
%                  a ScanImage tif, its tif page count is used as the frames
%                  per loop and length(movname) as the number of loops.
%   metadataname - .mat file holding rectPositions, cycleTime, stimparams,
%                  mov and movindicator.
%   results      - struct (only element 1 is inspected) with field
%                  preStimTime and optionally rawTracking or minoraxisFit
%                  (+ ballTracking or params); their frame count overrides
%                  the metadata frame count.
%
%   stiminfo     - 3 x framesperloop x numloops array, zero outside the
%                  stimulus frames. On the stimulus frames of page k:
%                  row 1 = k, row 2 = row of uniqueID presented, row 3 = 1.
%                  (assumes presentation k happens in loop k -- TODO confirm)
%   uniqueID     - sorted unique rows of stimID.
%   movieparams  - 2 x numel(stimparams): row 1 = movie class, row 2 = size.
%   stimID       - one row per presentation, 7 columns per simultaneous
%                  stimulus: start x, start y, movie type, end x, end y,
%                  size, duration.

dointerleave = 0;       % keeping the first channel of each movie for analysis

load(metadataname, 'rectPositions','cycleTime','stimparams','mov','movindicator')

% a 4th dimension in rectPositions holds the positions over time of moving stimuli
isMoving = size(rectPositions,4) > 1;
numRects = size(rectPositions,2);     % simultaneously presented stimuli
numStims = size(rectPositions,3);     % stimulus presentations

%% classify the movies being presented
% movieparams(1,:) is the class of each movie:
%   0 = blank, 1 = blinking box, 2 = bp filtered noise,
%   3 = constant sized circle (blinking), 4 = constant sized circle (bpf noise),
%   5 = moving circle
% movieparams(2,:) is the size of the stimulus.
% standardMovIndicator / movsize / movduration hold the class, size and
% duration for every entry of movindicator, so a class number always means
% the same movie type regardless of the order of the movies in stimparams.
standardMovIndicator = zeros(size(movindicator));
movsize = zeros(size(movindicator));
movduration = zeros(size(movindicator));
movieparams = nan(2,length(stimparams));

for n = 1:size(movieparams,2)
    curparams = stimparams{n};
    isCur = (movindicator == n);
    hasType = isfield(curparams,'stimtype');
    if hasType && strcmp(curparams.stimtype,'blank')
        movieparams(1,n) = 0;
        standardMovIndicator(isCur) = 0;
    elseif isfield(curparams,'circlesize') && ~isMoving         % stationary circle
        if hasType && strcmp(curparams.stimtype,'bpFilteredNoise')
            movieparams(1,n) = 4;
            standardMovIndicator(isCur) = 4;    % same class number as movieparams
        elseif hasType && isfield(curparams,'blinkrate')
            movieparams(1,n) = 3;
            standardMovIndicator(isCur) = 3;
        end
        movieparams(2,n) = curparams.circlesize;
        movsize(isCur) = curparams.circlesize;
        movduration(isCur) = curparams.totalStimLength;
    elseif isfield(curparams,'circlesize')                      % moving circle
        movieparams(:,n) = [5; curparams.circlesize];
        standardMovIndicator(isCur) = 5;
        movsize(isCur) = curparams.circlesize;
        % per presentation, so different moving movies keep their own duration
        movduration(isCur) = curparams.totalStimLength;
    else
        % full-image movie: positivePixels is the (magnified) side length of a
        % square holding all pixels of frame 1 that are not 0.5 (mid-gray)
        % (assumes mov{n} frames are imsize x imsize -- TODO confirm)
        positivePixels = round(sqrt(curparams.imsize^2-sum(sum((mov{n}(:,:,1) == 0.5))))*curparams.imageMag);
        cursize = curparams.imsize*curparams.imageMag;
        hasSwitch = isfield(curparams,'switchrate');
        if hasSwitch && ~isnan(curparams.switchrate) && cursize == positivePixels
            % blinking box covering the whole image
            movieparams(:,n) = [1; cursize];
            standardMovIndicator(isCur) = 1;
            movsize(isCur) = cursize;
        elseif hasSwitch && isnan(curparams.switchrate)
            % blank movie
            movieparams(:,n) = 0;
            standardMovIndicator(isCur) = 0;
        elseif isfield(curparams,'maxSpatFreq') && cursize == positivePixels
            % full image bandpass filtered noise
            movieparams(:,n) = [2; cursize];
            standardMovIndicator(isCur) = 2;
            movsize(isCur) = cursize;
        elseif isfield(curparams,'frequency') && isfield(curparams,'phase')
            movieparams(:,n) = [3; cursize];
            standardMovIndicator(isCur) = 3;
            movsize(isCur) = cursize;
        elseif isfield(curparams,'maxSpatFreq') && positivePixels ~= cursize
            % bandpass filtered noise covering only part of the image
            movieparams(:,n) = [4; cursize-positivePixels];
            standardMovIndicator(isCur) = 4;
            movsize(isCur) = cursize-positivePixels;
        end
    end
end

%% frame rate, frames per loop and number of loops
firstMovie = movname(1);
if iscell(firstMovie)
    firstMovie = firstMovie{1};
end
firstMovie = convertStringsToChars(firstMovie);

try
    moviemetadata = getSImetadata(firstMovie);

    if strcmp('\\research.files.med.harvard.edu\neurobio\GU LAB\TWO-PHOTON MICROSCOPY\visualStim\BI30-CAG-gcamp8m\4A\S2',cd)
        numloops = 8; framesperloop = 190; fps = 10;
        disp('PATCH JOB TO WORKAROUND METADATA ERROR')
    elseif strcmp('\\research.files.med.harvard.edu\neurobio\GU LAB\TWO-PHOTON MICROSCOPY\visualStim\cx40fl-fl_cx37fl-fl_bmxCreER\cohort 2\576\S6',cd)
        numloops = 10; framesperloop = 190; fps = 10;
        disp('PATCH JOB TO WORKAROUND METADATA ERROR')
    else
        fps = moviemetadata.framerate/moviemetadata.frameaveraging;
        framesperloop = moviemetadata.totalframes/moviemetadata.frameaveraging;
        numloops = moviemetadata.numLoops;
    end
catch
    % any failure above (including getSImetadata being unavailable) falls
    % back to counting tif pages with a fixed frame rate
    disp('not recognized as SI tif')
    fps = 10;
    movinfo = imfinfo(firstMovie);
    framesperloop = length(movinfo);
    numloops = length(movname);
    if dointerleave > 0
        % two interleaved channels: only half of the pages belong to one channel
        framesperloop = framesperloop/2;
    end
end

preStimTime = results(1).preStimTime;
stimframes = round(preStimTime*fps + 1):round(preStimTime*fps + cycleTime*fps);

%% describe every presentation (stimID) and the distinct stimuli (uniqueID)
% typeAndSize is 2 x numRects x numStims: row 1 = standardized movie type,
% row 2 = stimulus size (assumes movindicator holds one entry per rectangle
% per presentation)
typeAndSize = reshape([standardMovIndicator(:).'; movsize(:).'], 2, numRects, []);
col = @(A,row,r) reshape(A(row,r,:),[],1);      % one value per presentation
stimid2 = nan(numStims, 7*numRects);
if isMoving
    % moving stimuli are identified by first position, last position and duration
    startPos = rectPositions(:,:,:,1);
    endPos = rectPositions(:,:,:,end);
    durations = reshape(movduration, 1, numRects, []);
    for r = 1:numRects
        stimid2(:,(r-1)*7+(1:7)) = [col(startPos,1,r), col(startPos,2,r), col(typeAndSize,1,r), ...
            col(endPos,1,r), col(endPos,2,r), col(typeAndSize,2,r), col(durations,1,r)];
    end
else
    if isfield(stimparams{1},'totalStimLength')
        totalStimLength = stimparams{1}.totalStimLength;
    else
        totalStimLength = round(length(stimframes)/fps,1);
    end
    for r = 1:numRects
        x = col(rectPositions,1,r);
        y = col(rectPositions,2,r);
        % stationary stimuli: end position equals start position
        stimid2(:,(r-1)*7+(1:7)) = [x, y, col(typeAndSize,1,r), x, y, col(typeAndSize,2,r), ...
            repmat(totalStimLength,numStims,1)];
    end
end
[uniqueID,~,whichUnique] = unique(stimid2,'rows');
stimID = stimid2;       % this is the one that will be saved and its format is easy to process

%% reconcile the frame count with the processed data
% Sometimes averaging is forgotten on the acquisition side but done on the
% processing side before this; the processed arrays then hold fewer frames
% than the metadata claims, so fps and stimframes are rescaled to match.
analysisFrames = [];
if isfield(results,'rawTracking')
    analysisFrames = size(results(1).rawTracking,2);
elseif isfield(results,'minoraxisFit')
    % relies on the largest dimension of these arrays being the number of frames***
    if isfield(results,'ballTracking')
        numframes = max(size(results(1).ballTracking));
    else
        numframes = max(size(results(1).params));
    end
    % minoraxisFit may be stored with frames along dim 2 instead of dim 1
    fitSize = size(results(1).minoraxisFit);
    if fitSize(2) == numframes
        analysisFrames = fitSize(2);
    else
        analysisFrames = fitSize(1);
    end
end
if ~isempty(analysisFrames) && framesperloop ~= analysisFrames
    fps = fps/(framesperloop/analysisFrames);
    stimframes = round(preStimTime*fps + 1):round(preStimTime*fps + cycleTime*fps);
    framesperloop = analysisFrames;
end

%% per-frame annotation
stiminfo = zeros(3,framesperloop,numloops);
for n = 1:numloops
    stiminfo(1,stimframes,n) = n;
end
for k = 1:numel(whichUnique)
    stiminfo(2,stimframes,k) = whichUnique(k);
    stiminfo(3,stimframes,k) = 1;
end

% hypothetical use of this code for looking at some measure in each frame:
%stimAggregate = cell(1,numStimPositions);
%measurementVector = randi(1,totalframes);
%for n = 1:numStimPositions
%    temp = zeros(numframes,numTotalStims/numStimPositions);
%    temp(:) = measurementVector(stiminfo(2,:)==n);
%    stimAggregate{n} = temp;
%    meanTraj = mean(stimAggregate{n},2);
%    meanStimValue = mean(stimAggregate{n}(stimframes,:),2);
%end